
clear all

% Shared workspace state. dat/dat2/k are presumably read by the objective
% mr_no_randomE2 (not visible here) -- confirm which globals it uses.
global dat dat2 g D k

g       = 532;   % number of groups to load
N       = 0;     % running total of observations over all groups
grp     = 0;     % group counter advanced by the loading loop
bnewall = [];    % each outer-iteration estimate is appended as a column

% Seed the legacy normal generator for reproducibility. Nothing in this
% script draws from randn directly; it may matter for called functions.
randn('state',768795465)

    % Load every group's covariates X, outcome Y and weight matrix W into the
    % global struct array dat (fields xt, yt, wt), and accumulate N.
    %
    % Files read for group i (relative to the current directory):
    %   partP/data_partP_i.txt, P = 1..36   covariate blocks
    %   weight/w_i.txt                      (row, col, value) triplets of W
    %
    % Only parts 1-4 enter X. Parts 5-36 (the school dummies) are still read,
    % so a missing file is reported as before, but they are NOT appended to X.
    % NOTE(review): confirm the dummies are meant to be excluded -- k = 17 and
    % the 17 starting values used later suggest they are.
    partFile = @(p, gi) fullfile(sprintf('part%d', p), ...
                                 sprintf('data_part%d_%d.txt', p, gi));
    for i = 1:g
        %%% Load X and Y %%%
        d1 = load(partFile(1, i));
        d2 = load(partFile(2, i));
        d3 = load(partFile(3, i));
        d4 = load(partFile(4, i));

        % School dummies (parts 5-36): loaded but unused in this specification.
        schoolDummies = cell(1, 32);
        for p = 5:36
            schoolDummies{p-4} = load(partFile(p, i));
        end

        % Columns 1-4 of part 4 are covariates; column 5 is the outcome.
        Xx = [d1 d2 d3 d4(:,1:4)];
        % Append the squared first covariate, scaled by 1/10.
        X  = sparse([Xx, Xx(:,1).^2/10]);
        % Map the outcome (presumably 0/1) to -1/+1.
        Y  = sparse(2*d4(:,5) - 1);

        mr  = length(Y);   % number of observations in group i
        N   = N + mr;
        grp = grp + 1;

        %%% Load W %%%
        weight = load(fullfile('weight', sprintf('w_%d.txt', i)));
        if isempty(weight)
            % No entries in the weight file: W is all zeros rather than
            % failing on weight(:,1).
            ww = sparse(mr, mr);
        else
            % Duplicate (row, col) pairs are summed by sparse().
            ww = sparse(weight(:,1), weight(:,2), weight(:,3), mr, mr);
        end
        ww = ww - diag(diag(ww));   % zero the diagonal

        % Row-normalise: divide each row by its sum; zero-sum rows stay zero.
        rowSum    = full(sum(ww, 2));
        invRowSum = zeros(mr, 1);
        nz        = rowSum ~= 0;
        invRowSum(nz) = 1 ./ rowSum(nz);
        % spdiags keeps the scaling sparse; diag() would build a dense
        % mr-by-mr matrix for every group.
        Wr = spdiags(invRowSum, 0, mr, mr) * ww;

        dat(grp).yt = Y;
        dat(grp).xt = X;
        dat(grp).wt = Wr;
    end
    clear rowSum invRowSum nz

 k = 17;   % number of covariate coefficients = columns of each dat(i).xt

% Starting values for the k covariate coefficients, taken from the
% estimates of a previous model.
tb1 = [ 1.0211; -0.0129;  0.0213; -0.6114; -0.3261; -0.2875;  0.0958; ...
       -0.1673;  0.0433; -0.0722; -0.1002;  0.1091;  0.0674;  0.0613; ...
        0.0758; -0.1538; -0.2999];

% Full parameter vector: [covariate coefficients; coefficient on wt*x;
% intercept].
b1 = [tb1; 0; -8.7787];

flag  = 1;               % outer loop continues while flag == 1
count = 1;               % outer-iteration counter
SSE_b = zeros(300, 1);   % squared parameter change, one entry per iteration

% Outer iteration. Each pass:
%   1. for every group, iterates x <- tanh(X*beta + rho*W*x + c) from x = 1/2
%      to a fixed point at the current b1 and stores Mg = W*x in dat2;
%   2. re-estimates the parameters with fminunc on mr_no_randomE2
%      (defined elsewhere; presumably reads the globals dat/dat2/k);
%   3. stops when the squared parameter change is small, NaN, or after
%      maxOuter passes; otherwise takes a damped step towards bnew.
tolFix     = 1e-7;   % fixed-point threshold on the sum of squared changes
maxFixIter = 1e5;    % cap so a non-contracting map cannot hang the script
tolSSE     = 1e-4;   % outer convergence threshold on ||bnew - b1||^2
maxOuter   = 1000;   % outer iteration limit
while flag==1
    % Loop-invariant slices of the current parameter vector.
    beta = b1(1:k);
    rho  = b1(k+1);
    c0   = b1(k+2);

    for i = 1:g
        wt = dat(i).wt;
        xt = dat(i).xt;
        mr = length(dat(i).yt);

        x0     = ones(mr,1)/2;
        euclid = 1;
        nFix   = 0;
        while euclid > tolFix
            xx_new = tanh(xt*beta + wt*x0*rho + c0);
            dx     = xx_new - x0;
            euclid = sum(dx.*dx);
            x0     = xx_new;
            nFix   = nFix + 1;
            if nFix >= maxFixIter
                warning('npl:fixedPointNotConverged', ...
                    'Group %d: fixed point not reached after %d iterations (change %g).', ...
                    i, maxFixIter, euclid);
                break
            end
        end

        dat2(i).Mg = wt*x0;
    end

    %%%% MLE %%%%
    options = optimset('TolX',1e-7,'TolFun',1e-7,'MaxFunEvals',5e+3*count, ...
        'MaxIter',count*100,'GradObj','on','DerivativeCheck','off','Display','off');
    [bnew,fval(count),exitflag(count),output(count),grad,hessian] = fminunc(@mr_no_randomE2,b1,options);
    bnewall = [bnewall bnew];

    %%%% Convergence check %%%%
    SSE = (bnew-b1)'*(bnew-b1);
    SSE_b(count) = SSE;   % record every pass, including the final one
    if isnan(SSE) || SSE <= tolSSE || count > maxOuter
        flag = 0;
    else
        % Damped update: the step towards bnew shrinks as count grows.
        b1 = (bnew-b1)/ceil(count/2) + b1;
        count = count+1   % no semicolon: prints progress
    end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Calculate standard errors (outer product of scores)
%
% With theta = [beta; rho; c], Z = [X Mg 1] and index = Z*theta, the score
% of each observation is 2*((1+y)/2 - p) * d(index)/d(theta), where
% p = 1./(1+exp(-2*index)). The equilibrium x = tanh(index) is differentiated
% implicitly:
%   dx/dtheta     = (I - rho*D*W) \ (D*Z),  D = diag(tanh'(index))
%   dindex/dtheta = Z + rho*W*dx/dtheta
% Mg comes from dat2, i.e. the last fixed-point pass of the outer loop.

np   = k + 2;            % parameters: k coefficients, rho, intercept
graf = zeros(np, np);    % sum over groups of score' * score
rho  = bnew(k+1);
for i = 1:g
    yt = dat(i).yt;
    wt = dat(i).wt;
    xt = dat(i).xt;
    Mg = dat2(i).Mg;
    mr = size(xt, 1);

    Z    = [xt Mg ones(mr,1)];                    % includes the constant term
    bbf  = xt*bnew(1:k) + Mg*rho + bnew(k+2);     % latent index
    p1   = 1 ./ (1 + exp(-2*bbf));                % P(y = +1)
    deri = 4 ./ (exp(2*bbf) + exp(-2*bbf) + 2);  % tanh'(bbf) = sech(bbf)^2
    D    = spdiags(deri, 0, mr, mr);

    % Backslash solves the system directly instead of forming an explicit
    % (dense) inverse.
    dxdth = (speye(mr) - rho*D*wt) \ full(D*Z);
    didth = Z + rho*wt*dxdth;

    scores = spdiags(2*((1 + yt)/2 - p1), 0, mr, mr) * didth;
    graf   = graf + scores'*scores;
end
msscore = graf/N;                   % average outer product of scores
se = sqrt(diag(inv(msscore)/N));    % OPG standard errors
% Two-sided p-values: 2*(1 - normcdf(|z|)) equals erfc(|z|/sqrt(2)); erfc
% avoids cancellation in the far tail and needs no Statistics Toolbox.
result = full([bnew se erfc(abs(bnew./se)/sqrt(2))])
% -2 times the objective value from every outer pass; the last entry belongs
% to bnew. NOTE(review): presumably fval is a negative log-likelihood --
% confirm in mr_no_randomE2.
-2*fval
    

